Step 1: A segmentation mask is generated from the ground-truth bounding-box coordinates, so that the segmentation network can be trained end-to-end.
Step 2: Image augmentation techniques such as random crop, random scaling, random rotation and random shear are applied. Identical augmentation is applied to both the input image and the mask image.
Step 3: A Fully Convolutional Network is used for segmentation. Pre-trained VGG-11 with batch normalization is used as the encoder network. Its weights are not updated.
Step 4: The output of the network is a mask image. To obtain bounding boxes, the mask image is processed as follows:
The mask is converted into a binary image with a threshold of 0.5, and an opening operation is applied with a 5×5 filter.
Breadth First Search is applied to find out all Connected Components. A connected component represents a text region.
Bounding boxes of the connected components are calculated.
# Notebook illustration cells: display the documentation images for each
# pipeline stage.  NOTE(review): these bare `Image(...)` expressions only
# render inside a Jupyter notebook; in a plain script they are no-ops.
# This IPython `Image` is later shadowed by `from PIL import Image` below.
from IPython.display import Image
Image(filename=r"doc_images/input_image.png")
Image(filename=r"doc_images/generated_mask.png")
Image(filename=r"doc_images/binary_image.png")
Image(filename=r"doc_images/opening.png")
Image(filename=r"doc_images/connected_components.png")
Image(filename=r"doc_images/bounding_box.png")
#Import all required modules
# Standard library
import copy
import os
import random
import re
from collections import deque

# Third-party
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

# PyTorch / torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Select CUDA device 0 when a GPU is available, otherwise fall back to CPU.
# All tensors and the model are moved to this device before use.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device: ',device)
class TD_Dataset(Dataset):
    """Text-detection training dataset.

    Loads images from `image_path` and builds a binary segmentation mask per
    image from the ground-truth bounding boxes in `gt_path` (one file named
    'gt_<image name>.txt' per image).  Optionally applies identical spatial
    augmentation to the image and its mask.

    __getitem__ returns:
        (normalized_image_tensor, mask_tensor, original_size,
         original_boxes, resized_boxes)
    """
    # Canonical size every image/mask is resized to.  NOTE: the misspelled
    # 'hight' is kept because code elsewhere references TD_Dataset.hight.
    hight = 224
    width = 224

    def __init__(self, image_path=r'images', gt_path=r'ground_truth', augmentation=False):
        self.image_path = image_path
        self.gt_path = gt_path
        self.images_names = os.listdir(self.image_path)
        self.length = len(self.images_names)
        self.augmentation = augmentation
        # ImageNet statistics, matching the pre-trained VGG encoder.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return self.length

    @staticmethod
    def segmentation_transforms(image, mask):
        """Augment `image` and `mask` together.

        Photometric jitter / random grayscale affect the image only; every
        spatial transform (crop, resize, flips, affine) is applied with
        identical parameters to both, so the mask stays aligned.
        """
        # Photometric augmentation: input image only.
        image = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)(image)
        image = transforms.RandomGrayscale(0.05)(image)
        # Random 200x200 crop, then resize back to the canonical size.
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(200, 200))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)
        image = TF.resize(image, (TD_Dataset.hight, TD_Dataset.width))
        mask = TF.resize(mask, (TD_Dataset.hight, TD_Dataset.width))
        # Horizontal / vertical flips, each with probability 0.5.
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        # Random affine transform (rotation, translation, scale, shear),
        # keeping the image center invariant, with probability 0.5.
        if random.random() > 0.5:
            angle = random.randint(-20, 20)
            t_h, t_v = random.randint(-20, 20), random.randint(-20, 20)
            shear = random.randint(-20, 20)
            scale = random.randint(8, 12) / 10
            image = TF.affine(image, angle, (t_h, t_v), scale, shear)
            mask = TF.affine(mask, angle, (t_h, t_v), scale, shear)
        return image, mask

    def __getitem__(self, index):
        input_image_name = self.images_names[index]
        input_image = Image.open(os.path.join(self.image_path, input_image_name))
        input_image_size = input_image.size
        input_image = input_image.resize((TD_Dataset.hight, TD_Dataset.width))
        # Ground-truth file is 'gt_' + image name, with a .txt extension.
        output_file_path = os.path.join(self.gt_path, 'gt_' + input_image_name[:-4] + '.txt')
        output_image, original_boxes, resized_boxes = TD_Dataset.give_output_image(output_file_path, input_image_size)
        if self.augmentation == True:
            input_image, output_image = TD_Dataset.segmentation_transforms(input_image, output_image)
        # Normalize the input; the mask stays a plain 0/1 uint8 tensor.
        input_image = self.normalize(input_image)
        output_image = torch.from_numpy(np.array(output_image, dtype='uint8'))
        return (input_image, output_image, input_image_size, original_boxes, resized_boxes)

    @staticmethod
    def give_output_image(path, input_size):
        """Build the binary mask for one ground-truth file.

        Returns (mask_PIL_image, original_boxes, resized_boxes) where each
        box is a [x_min, y_min, x_max, y_max] list.
        """
        # FIX: read via a context manager so the file handle is closed
        # deterministically (the original leaked an open file per call).
        with open(path) as gt_file:
            text = gt_file.read()
        # One box per line; the last whitespace-separated token (the text
        # label) is dropped, keeping only the four coordinate tokens.
        boxes = [x.strip().split(' ')[:-1] for x in text.strip().split('\n')]
        # Convert the string tokens into integer coordinates.
        for i in range(len(boxes)):
            boxes[i] = [TD_Dataset.give_number(j) for j in boxes[i]]
        original_boxes = copy.deepcopy(boxes)
        # Rescale boxes from the original image size to the canonical size.
        for i in range(len(boxes)):
            boxes[i] = TD_Dataset.resize_boxes(boxes[i], input_size, (TD_Dataset.width, TD_Dataset.hight))
        # Rasterize: the mask is 1 inside every box, 0 elsewhere.
        image = np.zeros((TD_Dataset.hight, TD_Dataset.width), dtype='uint8')
        for box in boxes:
            image[box[1]:box[3], box[0]:box[2]] = 1
        return Image.fromarray(image, 'L'), original_boxes, boxes

    @staticmethod
    def resize_boxes(box, initial_size, final_size):
        """Rescale [x_min, y_min, x_max, y_max] in place from initial_size to final_size (both (w, h))."""
        box[0] = int(final_size[0] * (box[0] / initial_size[0]))
        box[1] = int(final_size[1] * (box[1] / initial_size[1]))
        box[2] = int(final_size[0] * (box[2] / initial_size[0]))
        box[3] = int(final_size[1] * (box[3] / initial_size[1]))
        return box

    @staticmethod
    def give_number(string):
        """Extract the integer from a token by stripping all non-digit characters."""
        # FIX: raw string for the regex -- the bare "\D" escape is invalid in
        # a plain string literal and warns on recent Python versions.
        return int(re.sub(r"\D", "", string.strip()))
# Training dataset without augmentation (used by inspect() for clean plots).
train_data = TD_Dataset(augmentation=False)
# Training dataset WITH augmentation, plus the loader used for training.
aug_train_data = TD_Dataset(augmentation=True)
# BUG FIX: the loader previously wrapped `train_data`, so training never saw
# the augmented dataset; it now wraps `aug_train_data` as its name implies.
aug_train_dataloader = DataLoader(aug_train_data, batch_size=24, shuffle=True, num_workers=8)
#Code for inspecting augmented images and masks
#Does inverse of input image normalization so that it can be plotted using matplotlib
def inverse_preprocess(image):
    """Undo the ImageNet normalization in place and return the tensor.

    Expects a batched tensor of shape (N, 3, H, W); each channel is
    de-normalized with the same mean/std used by the dataset transform,
    so the result can be displayed with matplotlib.
    """
    # (mean, std) per RGB channel, in channel order.
    stats = [(0.485, 0.229), (0.456, 0.224), (0.406, 0.225)]
    for channel, (mean, std) in enumerate(stats):
        image[:, channel, :, :] *= std
        image[:, channel, :, :] += mean
    return image
#Plot some augmented images and their corresponding masks
N = 10
# Visual sanity-check: show N randomly chosen augmented samples next to
# their (identically augmented) masks.
for sample_idx in random.sample(range(0, len(train_data)), N):
    sample = aug_train_data[sample_idx]
    # Undo the normalization so matplotlib shows sensible colors.
    restored = inverse_preprocess(sample[0].view(1, 3, TD_Dataset.hight, TD_Dataset.width))
    fig, ax = plt.subplots(1, 2)
    fig.set_figheight(10)
    fig.set_figwidth(10)
    ax[0].imshow(restored.view(3, TD_Dataset.hight, TD_Dataset.width).permute(1, 2, 0))
    ax[0].title.set_text('Input Image')
    ax[1].imshow(sample[1].view(TD_Dataset.hight, TD_Dataset.width), cmap='gray')
    ax[1].title.set_text('Mask')
    for panel in ax:
        panel.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
#VGG feature extractor of FCN.
class FCN_VGG(nn.Module):
    """Frozen VGG-11(BN) encoder exposing intermediate feature maps.

    forward() returns the activation after each of the five max-pooling
    layers, shallowest first, for use as skip connections in the decoder.
    """

    # Indices of the five MaxPool2d layers inside vgg11_bn().features.
    POOL_LAYERS = (3, 7, 14, 21, 28)

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.vgg11_bn(pretrained=True)
        # The fully connected head is not needed for segmentation.
        del self.model.classifier
        # The encoder is frozen: its weights are never updated.
        for parameter in self.model.parameters():
            parameter.requires_grad = False

    def forward(self, x):
        # PERF FIX: a single pass through the feature stack, tapping the
        # activation right after every pooling layer.  The original re-ran
        # the network from the input for each tap point, recomputing the
        # shared prefix (~5x the necessary work) for identical outputs.
        outputs = []
        for layer_number, layer in enumerate(self.model.features):
            x = layer(x)
            if layer_number in FCN_VGG.POOL_LAYERS:
                outputs.append(x)
        return outputs
#FCN
class FCN(nn.Module):
    """FCN-style segmentation decoder on a frozen VGG-11(BN) encoder.

    The decoder upsamples 32x with five stride-2 transposed convolutions,
    adding the matching encoder feature map (skip connection) after each of
    the first four, and ends with a 1x1 conv + sigmoid that produces a
    single-channel mask in [0, 1] at input resolution.
    """
    def __init__(self):
        super().__init__()
        # Feature extractor (frozen VGG).
        self.vgg = FCN_VGG()
        # Decoder: each transposed conv doubles spatial resolution.
        self.transconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.batch_norm1 = nn.BatchNorm2d(512)
        self.transconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.batch_norm2 = nn.BatchNorm2d(256)
        self.transconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.batch_norm3 = nn.BatchNorm2d(128)
        self.transconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.batch_norm4 = nn.BatchNorm2d(64)
        self.transconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.batch_norm5 = nn.BatchNorm2d(32)
        self.conv1 = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        # Encoder feature maps, shallowest first:
        # skips[0]: (B, 64, 112, 112)   skips[1]: (B, 128, 56, 56)
        # skips[2]: (B, 256, 28, 28)    skips[3]: (B, 512, 14, 14)
        # skips[4]: (B, 512, 7, 7)
        skips = self.vgg(x)
        # BUG FIX: the original applied `nn.Dropout2d(0.2)(x)` constructed
        # inside forward(); a freshly built module is always in training
        # mode, so dropout stayed active even after net.eval().
        # F.dropout2d honours self.training set by net.train()/net.eval().
        x = self.batch_norm1(F.relu(self.transconv1(skips[4])))  # (B, 512, 14, 14)
        x = x + skips[3]                                          # skip connection
        x = F.dropout2d(x, p=0.2, training=self.training)         # 2D channel dropout
        x = self.batch_norm2(F.relu(self.transconv2(x)))          # (B, 256, 28, 28)
        x = x + skips[2]                                          # skip connection
        x = F.dropout2d(x, p=0.2, training=self.training)
        x = self.batch_norm3(F.relu(self.transconv3(x)))          # (B, 128, 56, 56)
        x = x + skips[1]                                          # skip connection
        x = F.dropout2d(x, p=0.2, training=self.training)
        x = self.batch_norm4(F.relu(self.transconv4(x)))          # (B, 64, 112, 112)
        x = x + skips[0]                                          # skip connection
        x = F.dropout2d(x, p=0.2, training=self.training)
        x = self.batch_norm5(F.relu(self.transconv5(x)))          # (B, 32, 224, 224)
        x = self.conv1(x)                                         # (B, 1, 224, 224)
        return torch.sigmoid(x)
# Instantiate the network and move it to the selected device
# (no-op move when device is 'cpu').
net=FCN()
#Convert net to cuda if available
net.to(device)
#Bounds check: does a 2-D index lie inside an array of the given shape?
def is_inside(index, size):
    """Return True iff (row, col) `index` is a valid position for shape `size`."""
    row, col = index
    return 0 <= row < size[0] and 0 <= col < size[1]
#Breadth First Search over a binary mask: flood-fills one connected
#component (text region) and returns its bounding-box co-ordinates.
def give_box(image, visited, index):
    """Flood-fill the 4-connected component of 1-pixels containing `index`.

    image   -- 2-D binary array (0/1)
    visited -- boolean array of the same shape, updated in place
    index   -- (row, col) seed inside the component
    Returns [min_y, min_x, max_y, max_x] (y = column, x = row), matching
    the original co-ordinate convention.
    """
    rows, cols = image.shape
    # PERF FIX: deque gives O(1) popleft; the original used list.pop(0),
    # which is O(n) per pop and quadratic over large components.
    queue = deque([index])
    visited[index] = True
    # Seed the bounding box with the initial node.
    max_x, max_y = index
    min_x, min_y = index
    while queue:
        x, y = queue.popleft()
        # Grow the bounding box to include this pixel.
        if x < min_x:
            min_x = x
        if x > max_x:
            max_x = x
        if y < min_y:
            min_y = y
        if y > max_y:
            max_y = y
        # Explore the 4 neighbours (left, up, right, down).
        for dx, dy in ((0, -1), (-1, 0), (0, 1), (1, 0)):
            nx, ny = x + dx, y + dy
            # In-bounds foreground pixel not yet visited -> same component.
            if 0 <= nx < rows and 0 <= ny < cols and image[nx, ny] == 1 and not visited[nx, ny]:
                queue.append((nx, ny))
                visited[nx, ny] = True
    return [min_y, min_x, max_y, max_x]
def give_boxes(image):
    """Convert a predicted mask tensor into bounding boxes.

    Pipeline: threshold at 0.5 -> morphological opening with a 5x5 kernel
    (removes noise and separates close text regions) -> treat the binary
    image as a graph whose edges join adjacent 1-pixels and run BFS to
    find every connected component, returning one
    [min_y, min_x, max_y, max_x] box per component.
    """
    # Binarize the generated mask.
    mask = (image.numpy() > 0.5).astype('uint8')
    # Opening = erosion followed by dilation.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
    # All pixels start unvisited; give_box marks them as it flood-fills.
    visited = np.zeros(mask.shape, dtype='bool')
    boxes = []
    # Each unvisited foreground pixel seeds a new component.
    for pixel_index, value in np.ndenumerate(mask):
        if value == 1 and not visited[pixel_index]:
            boxes.append(give_box(mask, visited, pixel_index))
    return boxes
#Draw every bounding box onto a copy of the image.
def draw_boxes(boxes, image, width, color):
    """Return a clone of `image` with all `boxes` drawn in `color`."""
    annotated = image.clone()
    for box in boxes:
        draw_box(box, annotated, width, color)
    return annotated
#Draw a single bounding box around a text region (in place).
def draw_box(box, image, width, color):
    """Draw a rectangle on a (3, H, W) RGB tensor, modifying it in place.

    box   -- [x_min, y_min, x_max, y_max]
    width -- half the border thickness in pixels
    color -- 'red', 'green' or 'blue'
    """
    # Channel index for each color in an RGB tensor.
    Color = {'red': 0, 'green': 1, 'blue': 2}
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    # BUG FIX: clamp the lower slice bounds at 0.  For boxes touching the
    # top/left edge, box[i]-width is negative and a negative slice start
    # yields an empty slice, so those borders were silently not drawn.
    top = max(y0 - width, 0)
    left = max(x0 - width, 0)
    bottom = max(y1 - width, 0)
    right = max(x1 - width, 0)
    # Blank all channels along the four borders first.
    image[:, top:y0 + width, x0:x1] = 0
    image[:, y0:y1, left:x0 + width] = 0
    image[:, bottom:y1 + width, x0:x1] = 0
    image[:, y0:y1, right:x1 + width] = 0
    # Then paint the selected channel along the same borders.
    channel = Color[color]
    image[channel, top:y0 + width, x0:x1] = 1
    image[channel, y0:y1, left:x0 + width] = 1
    image[channel, bottom:y1 + width, x0:x1] = 1
    image[channel, y0:y1, right:x1 + width] = 1
def inspect(N=3):
    """Plot N random training samples: input image, predicted mask, boxes.

    Reads the module-level `net`, `device` and `train_data`.  Ground-truth
    boxes (resized to the 224x224 frame) are drawn in green; boxes derived
    from the predicted mask are drawn in red.  Leaves the network in eval
    mode afterwards (train_net() switches it back each epoch).
    """
    #Convert net to evaluation mode
    net.eval()
    #Make subplot grid: one row per sample, columns = image / mask / boxes
    figure,axis = plt.subplots(N, 3)
    figure.set_figheight(N*5+2)
    figure.set_figwidth(15)
    with torch.no_grad():
        for i in range(N):
            #Randomly sample an index from the non-augmented dataset
            index=random.randint(0, len(train_data)-1)
            data=train_data[index]
            image=data[0].view(1,3,TD_Dataset.hight,TD_Dataset.width)
            #Generate the mask prediction for this image
            output=net(image.to(device)).detach().to('cpu').view(TD_Dataset.hight,TD_Dataset.width)
            #Inverse preprocess image for plotting (NOTE: mutates `image` in place)
            org_image=inverse_preprocess(image).view(3,TD_Dataset.hight,TD_Dataset.width)
            #Plot image, mask and bounding boxes
            axis[i,0].imshow(org_image.permute(1,2,0))
            axis[i,0].title.set_text('Input Image')
            axis[i,1].imshow(output,cmap='gray')
            axis[i,1].title.set_text('Generated Mask')
            #Ground-truth boxes (data[-1] is resized_boxes from __getitem__) in green
            boxed_image=draw_boxes(data[-1],org_image,1,'green')
            #Predicted boxes extracted from the generated mask, in red
            boxed_image=draw_boxes(give_boxes(output),boxed_image,1,'red')
            axis[i,2].imshow(boxed_image.permute(1,2,0))
            axis[i,2].title.set_text('Input Image with bounding boxes')
            axis[i,2].text(0, 8, 'Predicted', bbox={'facecolor': 'Red', 'pad': 0})
            axis[i,2].text(0, 20, 'Ground Truth', bbox={'facecolor': 'Green', 'pad': 0})
            axis[i,0].axis('off')
            axis[i,1].axis('off')
            axis[i,2].axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
# Optimizer for training the network.  All parameters are passed in, but
# the VGG encoder's parameters have requires_grad=False (set in FCN_VGG),
# so only the decoder is effectively updated.
optimizer = optim.Adam(net.parameters(), lr=0.001)
#Function to train the network
def train_net(epochs=20):
    """Train `net` on `aug_train_dataloader` for `epochs` epochs.

    Uses the module-level `net`, `optimizer`, `aug_train_dataloader`,
    `device` and `inspect`.  Prints the average batch loss per epoch and
    plots sample predictions every 10th epoch.
    """
    # Hoisted out of the batch loop: BCELoss is stateless, so there is no
    # need to construct a new instance on every iteration.
    loss_fn = nn.BCELoss()
    for e in range(1,epochs+1):
        print('-------------------------Epoch Number: {}---------------------'.format(e))
        avg_loss = 0
        # inspect() switches the model to eval mode; switch back for training.
        net.train()
        for i, data in enumerate(aug_train_dataloader, 1):
            # Clear accumulated gradients.
            optimizer.zero_grad()
            X = data[0].to(device)
            # Masks are uint8; BCELoss needs float targets.
            y = data[1].to(device).to(dtype=torch.float)
            output = net(X)
            # Drop the singleton channel dim: (B, 1, H, W) -> (B, H, W).
            output = output.view(output.shape[0], output.shape[2], output.shape[3])
            loss = loss_fn(output, y)
            loss.backward()
            avg_loss += loss.item()
            # Update the weights.
            optimizer.step()
        print('Average Loss :',avg_loss/i)
        if e % 10 == 0:
            inspect()
train_net(120)
class Un_seen_Dataset(Dataset):
    """Dataset of unlabelled images used only for qualitative evaluation.

    __getitem__ returns (normalized_image_tensor, original_size); there is
    no ground truth, so no mask is produced.
    """
    def __init__(self, image_path=r'Unseen'):
        self.image_path = image_path
        self.images_names = os.listdir(self.image_path)
        self.length = len(self.images_names)
        # Same ImageNet normalization as the training dataset.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        name = self.images_names[index]
        img = Image.open(os.path.join(self.image_path, name))
        original_size = img.size
        # Resize to the canonical network input size, then normalize.
        img = img.resize((TD_Dataset.hight, TD_Dataset.width))
        return (self.normalize(img), original_size)
# Qualitative evaluation on unlabelled images: for every unseen image, plot
# the input, the predicted mask, and the predicted bounding boxes.
Un_seen_data=Un_seen_Dataset()
Un_seen_dataloader = DataLoader(Un_seen_data, batch_size=1, shuffle=True, num_workers=0)
#Convert net to evaluation mode
net.eval()
with torch.no_grad():
    for i,data in enumerate(Un_seen_dataloader):
        #Make a 1x3 subplot: image / mask / boxes
        figure,axis = plt.subplots(1, 3)
        figure.set_figheight(5)
        figure.set_figwidth(15)
        image = data[0].view(1,3,TD_Dataset.hight,TD_Dataset.width)
        #Generate the mask prediction
        output = net(image.to(device)).detach().to('cpu').view(TD_Dataset.hight,TD_Dataset.width)
        #Inverse of normalization for plotting (NOTE: mutates `image` in place)
        org_image = inverse_preprocess(image).view(3,TD_Dataset.hight,TD_Dataset.width)
        #Plot image, mask and bounding boxes
        axis[0].imshow(org_image.permute(1,2,0))
        axis[0].title.set_text('Input Image')
        axis[1].imshow(output,cmap='gray')
        axis[1].title.set_text('Generated mask')
        #Predicted bounding boxes (red); no ground truth exists here
        boxed_image=draw_boxes(give_boxes(output),org_image,1,'red')
        axis[2].imshow(boxed_image.permute(1,2,0))
        axis[2].title.set_text('Input Image with bounding boxes')
        axis[0].axis('off')
        axis[1].axis('off')
        axis[2].axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.show()